#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  5 12:55:23 2023

@author: njr8582
"""
###This script replicates Figure C1 in the online appendix. It first recreates the left panel and then the right panel.

###Load packages
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

###Define primitives
# Parameters per country, in the order (k, a, w_bar, w_min).  Country i:
k_i, a_i, w_bar_i, w_min_i = 1.2, 0.3, 0, -0.8
# Country j (identical values, so the two countries are symmetric):
k_j, a_j, w_bar_j, w_min_j = 1.2, 0.3, 0, -0.8

# With sunk costs we must find when the most resolved type of each country wants
# to move into the first screening phase.  Only country i's parameters enter here.
T_p1 = (1 - w_bar_i)*(1/a_i + 1/k_i) - 1/a_i

###We now start solving for the equilibrium.

## Types are taken to be uniform on [w_min, w_bar]; the width of that support is
## the denominator of every CDF evaluation F(w) = (w - w_min)/range below.
range_i, range_j = w_bar_i - w_min_i, w_bar_j - w_min_j

def betaipeace(t):
    """Return country i's uniform type CDF evaluated at the type w = -a_i*t.

    That is (w - w_min_i)/range_i with w = -a_i*t, so the value falls linearly in
    t.  Presumably this is the peaceful-phase cutoff beta_i (per the name).
    Works elementwise when ``t`` is a NumPy array.
    """
    cutoff_type = -(a_i*t)
    return (cutoff_type - w_min_i)/range_i

def betajpeace(t):
    """Return country j's uniform type CDF evaluated at the type w = -a_j*t.

    That is (w - w_min_j)/range_j with w = -a_j*t, so the value falls linearly in
    t.  Presumably this is the peaceful-phase cutoff beta_j (per the name).
    Works elementwise when ``t`` is a NumPy array.
    """
    cutoff_type = -(a_j*t)
    return (cutoff_type - w_min_j)/range_j

### Concession CDF for country i: the share of i's types that have conceded by time t,
#  assuming no mass of types concedes at time 0 and both countries follow their
#  peaceful-phase concession rates.  It depends only on j's parameters (a_j, k_j).

def HDC_i(t):
    """Return 1 - (1/(1 + a_j*t))**((a_j + k_j)/a_j), the share of i's types conceded by t."""
    shrink = 1/(1 + a_j*t)
    power = (a_j + k_j)/a_j
    return 1 - shrink**power

def HDC_j(t):
    """Return 1 - (1/(1 + a_i*t))**((a_i + k_i)/a_i), the share of j's types conceded by t.

    Mirror image of HDC_i: it depends only on i's parameters (a_i, k_i).
    """
    shrink = 1/(1 + a_i*t)
    power = (a_i + k_i)/a_i
    return 1 - shrink**power

# Restrict attention to realistic times: stop where betaipeace/betajpeace reach 0,
# i.e. where -a*t hits w_min because audience costs have grown too large.
t_max_i, t_max_j = -w_min_i/a_i, -w_min_j/a_j
t_max = min(t_max_i, t_max_j)  # the earlier of the two countries' limits

# Common 1000-point time grid on [0, t_max].
t = np.linspace(0, t_max, 1000)

# Country i's peaceful-phase cutoff and concession CDF on the grid
# (neither array is referenced again below).
y1i = betaipeace(t)
y2i = HDC_i(t)

# Next we move to calculating the screening-phase strategies.

def ssp_strategies(t_ssp, F):
    """Right-hand side of the screening-phase ODE system, in solve_ivp's (t, y) form.

    Parameters
    ----------
    t_ssp : float
        Current time.
    F : array-like of length 4
        CDF values (F_i(0), F_i(1), F_j(0), F_j(1)).  For each country, the first
        value maps to its lower type bound and the second to its upper type bound.

    Returns
    -------
    list of 4 floats
        Time derivatives, in the same order as ``F``.

    Reads k_i, k_j, a_i, a_j, range_i, range_j, w_min_i and w_min_j from module
    scope at call time.
    """
    Fi_lo, Fi_hi, Fj_lo, Fj_hi = F[0], F[1], F[2], F[3]

    # Map CDF values back to type levels under the uniform type distribution.
    lo_i = Fi_lo*range_i + w_min_i
    hi_i = Fi_hi*range_i + w_min_i
    lo_j = Fj_lo*range_j + w_min_j
    hi_j = Fj_hi*range_j + w_min_j

    # CDF mass between each country's two bounds, and the distance 1 - upper bound.
    mass_i, mass_j = Fi_hi - Fi_lo, Fj_hi - Fj_lo
    gap_i, gap_j = 1 - hi_i, 1 - hi_j

    # NOTE(review): d_i0 is scaled by k_i while d_i1 uses k_j, and both of j's rates
    # use k_i, so the i/j treatment is not symmetric.  With k_i == k_j (as set in this
    # script) there is no numerical effect; confirm which k each rate should use
    # before running asymmetric parameterizations.
    d_i0 = k_i*mass_j/gap_j
    d_i1 = ((k_j*(hi_j + a_j*t_ssp) - a_j*gap_j)*mass_j) / (gap_j*(lo_j + a_j*t_ssp))
    d_j0 = k_i*mass_i/gap_i
    d_j1 = ((k_i*(hi_i + a_i*t_ssp) - a_i*gap_i)*mass_i) / (gap_i*(lo_i + a_i*t_ssp))
    return [d_i0, d_i1, d_j0, d_j1]

###Remains to set the time bounds we're interested in and the initial conditions for F's
# Report the ODE solution on the grid points of t that lie strictly after T_p1.
t_element_ssp = np.searchsorted(t, T_p1, side='right')  # index of the first t > T_p1
teval_ssp = t[t_element_ssp:]
# Integrate from the transition time T_p1 up to t_max_j.
t_span_ssp = np.append(T_p1, np.array([t_max_j]))

# Each country's lower CDF bound starts at the share of its own types that conceded
# during the peaceful phase (HDC_i for i, HDC_j for j); the upper bound starts at 1.
F_i0_ssp_initial = HDC_i(T_p1)
F_i1_ssp_initial = 1
F_j0_ssp_initial = HDC_j(T_p1)
F_j1_ssp_initial = 1

sln_ssp = solve_ivp(ssp_strategies, t_span_ssp,
                    [F_i0_ssp_initial, F_i1_ssp_initial, F_j0_ssp_initial, F_j1_ssp_initial],
                    t_eval=teval_ssp)

t_ssp = sln_ssp.t  # the set of times for which there is a solution to the ODE
sigma_i0_ssp = sln_ssp.y[0]  # i's ssp concession strategy
sigma_i1_ssp = sln_ssp.y[1]  # i's ssp escalation strategy
sigma_j0_ssp = sln_ssp.y[2]  # j's ssp concession strategy
sigma_j1_ssp = sln_ssp.y[3]  # j's ssp escalation strategy

# We now need to find the end date Tbar: the last solved time before betaipeace(t)
# first changes sign relative to i's lower bound sigma_i0_ssp.
y1i_ssp = betaipeace(t_ssp)
idxi_ssp = np.argwhere(np.diff(np.sign(y1i_ssp - sigma_i0_ssp))).flatten()
if idxi_ssp.size == 0:
    # Without a crossing there is no horizon date; fail with a clear message
    # rather than an IndexError.
    raise RuntimeError(
        "betaipeace never crosses sigma_i0_ssp on the solved interval; "
        "cannot determine Tbar (solve_ivp: {})".format(sln_ssp.message))
Tbar = t_ssp[idxi_ssp[0]]

# Build the two-panel figure: left panel uses k = 1.2, right panel k = 3.
T1 = T_p1  # transition time into the screening phase for this panel

fig, (ax1, ax2) = plt.subplots(1, 2)
fig.suptitle(r'Introducing Sunk Costs ($\bar{w}_i = 0$)')
ax1.title.set_text(r'$k_i = 1.2$')
ax2.title.set_text(r'$k_i = 3$')

# Horizontal extent runs 10% past the horizon date Tbar; the vertical extent is fixed.
endpoint_fig = Tbar + Tbar*0.1
# Positions as fractions of the x-range (axes coordinates), as axhline's xmax expects.
endpoint_q = T1/endpoint_fig       # end of the peaceful phase, used for Q_i(t)
endpoint_Tbar = Tbar/endpoint_fig  # the horizon date
ax1.set_xlim(left=0, right=endpoint_fig)
ax1.set_ylim(bottom=w_min_i, top=w_bar_i + 0.2)


#Then we move onto plotting the peaceful phase components
# Type level reached by i's peaceful-phase concessions at T1 (HDC_i mapped back from
# CDF units).  Because the countries play mixed strategies, one line suffices.
Beta_i = HDC_i(T1)*range_i + w_min_i
# Height of the y-window set by ax1.set_ylim(w_min_i, w_bar_i + 0.2); dividing by it
# turns data y-values into the axes fractions that axvline's ymin/ymax expect.
y_range_i = range_i + 0.2


#Start by Plotting the Peaceful Phase
ax1.axhline(y=Beta_i, xmin=0, xmax=endpoint_q)  # Plotting Q_i(t)
ax1.axvline(x=T1, ymin=0, ymax=(Beta_i - w_min_i)/y_range_i, ls='--')  # Dashed line when the concessions are done
ax1.fill_between(x=[0, T1], y1=Beta_i, y2=w_min_i, facecolor='none', edgecolor='b', hatch='//')
ax1.annotate(r'$\beta_i^p$', xy=(0, Beta_i), xytext=(T1/9, Beta_i + y_range_i*0.1),
             arrowprops=dict(arrowstyle='simple'))
ax1.axvline(x=T1, ymin=(Beta_i - w_min_i)/y_range_i, ymax=(w_bar_i - w_min_i)/y_range_i, ls='--', color='k')
# Place the T_p label relative to country i's upper bound, as in the right panel.
ax1.annotate(r'T$_p$', xy=(T1 - T1/15, w_bar_i + 0.05))

# Now plot the screening-phase strategies, truncated at the horizon date Tbar.
enddate = np.where(t_ssp == Tbar)[0][0]  # position of Tbar within t_ssp
t_ssp_fig, sigma_i0_ssp_fig, sigma_i1_ssp_fig, sigma_j0_ssp_fig, sigma_j1_ssp_fig = (
    series[:enddate]
    for series in (t_ssp, sigma_i0_ssp, sigma_i1_ssp, sigma_j0_ssp, sigma_j1_ssp)
)

# Only country i is drawn: lower bound in the default colour, upper bound in red,
# both mapped back from CDF values to type levels.
ax1.plot(t_ssp_fig, sigma_i0_ssp_fig*range_i + w_min_i)
ax1.plot(t_ssp_fig, sigma_i1_ssp_fig*range_i + w_min_i, color='r')


#Then we move onto plotting the strategies the mass of types going to war
#Country i
# Type levels bounding i's remaining mass at the horizon date (CDF values mapped back).
w_i_max_Tbar = sigma_i1_ssp[enddate]*range_i + w_min_i
w_i_min_Tbar = sigma_i0_ssp[enddate]*range_i + w_min_i
ax1.axvline(x=Tbar, ymin=(w_i_min_Tbar - w_min_i)/y_range_i, ymax=(w_i_max_Tbar - w_min_i)/y_range_i,
            color='r')  # Plotting wars at the horizon date
ax1.axhline(y=w_bar_i, xmin=0, xmax=endpoint_q, ls='--', color='r')  # Dashed line up to the point where war begins
# The dashed line sits at y = w_bar_i, so label it with country i's bound
# (matching the right panel).
ax1.annotate(r'$\widebar{w}_i$', xy=(0, w_bar_i), xytext=(T1/9, w_bar_i - y_range_i*0.1),
             arrowprops=dict(arrowstyle='simple', facecolor='red'))

# Black dashed continuations of the Tbar line above and below the red war segment.
ax1.axvline(x=Tbar, ymin=(w_i_max_Tbar - w_min_i)/y_range_i, ymax=(w_bar_i - w_min_i)/y_range_i, ls='--', color='k')
ax1.axvline(x=Tbar, ymin=0, ymax=(w_i_min_Tbar - w_min_i)/y_range_i, ls='--', color='k')
ax1.annotate(r'$\bar{T}$', xy=(Tbar, w_bar_i + 0.04))

### Second panel: raise both countries' k to 3 (all other primitives unchanged).
k_i = k_j = 3

# The screening-phase transition time depends on k_i, so recompute it.
T_p1 = (1 - w_bar_i)*(1/a_i + 1/k_i) - 1/a_i

### Redefine the screening-phase ODE right-hand side for the second panel.  It reads
#   k_i and k_j from module scope when called, so the updated values are picked up.
def ssp_strategies(t_ssp, F):
    """Right-hand side of the screening-phase ODE system, in solve_ivp's (t, y) form.

    Parameters
    ----------
    t_ssp : float
        Current time.
    F : array-like of length 4
        CDF values (F_i(0), F_i(1), F_j(0), F_j(1)).  For each country, the first
        value maps to its lower type bound and the second to its upper type bound.

    Returns
    -------
    list of 4 floats
        Time derivatives, in the same order as ``F``.

    Reads k_i, k_j, a_i, a_j, range_i, range_j, w_min_i and w_min_j from module
    scope at call time.
    """
    Fi_lo, Fi_hi, Fj_lo, Fj_hi = F[0], F[1], F[2], F[3]

    # Map CDF values back to type levels under the uniform type distribution.
    lo_i = Fi_lo*range_i + w_min_i
    hi_i = Fi_hi*range_i + w_min_i
    lo_j = Fj_lo*range_j + w_min_j
    hi_j = Fj_hi*range_j + w_min_j

    # CDF mass between each country's two bounds, and the distance 1 - upper bound.
    mass_i, mass_j = Fi_hi - Fi_lo, Fj_hi - Fj_lo
    gap_i, gap_j = 1 - hi_i, 1 - hi_j

    # NOTE(review): d_i0 is scaled by k_i while d_i1 uses k_j, and both of j's rates
    # use k_i, so the i/j treatment is not symmetric.  With k_i == k_j (as set in this
    # script) there is no numerical effect; confirm which k each rate should use
    # before running asymmetric parameterizations.
    d_i0 = k_i*mass_j/gap_j
    d_i1 = ((k_j*(hi_j + a_j*t_ssp) - a_j*gap_j)*mass_j) / (gap_j*(lo_j + a_j*t_ssp))
    d_j0 = k_i*mass_i/gap_i
    d_j1 = ((k_i*(hi_i + a_i*t_ssp) - a_i*gap_i)*mass_i) / (gap_i*(lo_i + a_i*t_ssp))
    return [d_i0, d_i1, d_j0, d_j1]

###Remains to set the time bounds we're interested in and the initial conditions for F's
# Report the ODE solution on the grid points of t that lie strictly after T_p1.
t_element_ssp = np.searchsorted(t, T_p1, side='right')  # index of the first t > T_p1
teval_ssp = t[t_element_ssp:]
# Integrate from the transition time T_p1 up to t_max_j.
t_span_ssp = np.append(T_p1, np.array([t_max_j]))

# Each country's lower CDF bound starts at the share of its own types that conceded
# during the peaceful phase (HDC_i for i, HDC_j for j); the upper bound starts at 1.
F_i0_ssp_initial = HDC_i(T_p1)
F_i1_ssp_initial = 1
F_j0_ssp_initial = HDC_j(T_p1)
F_j1_ssp_initial = 1

sln_ssp = solve_ivp(ssp_strategies, t_span_ssp,
                    [F_i0_ssp_initial, F_i1_ssp_initial, F_j0_ssp_initial, F_j1_ssp_initial],
                    t_eval=teval_ssp)

t_ssp = sln_ssp.t  # the set of times for which there is a solution to the ODE
sigma_i0_ssp = sln_ssp.y[0]  # i's ssp concession strategy
sigma_i1_ssp = sln_ssp.y[1]  # i's ssp escalation strategy
sigma_j0_ssp = sln_ssp.y[2]  # j's ssp concession strategy
sigma_j1_ssp = sln_ssp.y[3]  # j's ssp escalation strategy

# We now need to find the end date Tbar: the last solved time before betaipeace(t)
# first changes sign relative to i's lower bound sigma_i0_ssp.
y1i_ssp = betaipeace(t_ssp)
idxi_ssp = np.argwhere(np.diff(np.sign(y1i_ssp - sigma_i0_ssp))).flatten()
if idxi_ssp.size == 0:
    # Without a crossing there is no horizon date; fail with a clear message
    # rather than an IndexError.
    raise RuntimeError(
        "betaipeace never crosses sigma_i0_ssp on the solved interval; "
        "cannot determine Tbar (solve_ivp: {})".format(sln_ssp.message))
Tbar = t_ssp[idxi_ssp[0]]
T1 = T_p1

# Axis limits for the right panel: x runs 10% past the horizon date Tbar.
endpoint_fig = Tbar + Tbar*0.1
endpoint_q = T1/endpoint_fig  # axes fraction where the peaceful phase ends (axhline xmax)
endpoint_Tbar = Tbar/endpoint_fig
ax2.set_xlim(0, endpoint_fig)
# The y-window must match y_range_i = range_i + 0.2, which the plotting code uses to
# turn country i's type levels into axes fractions, so build it from i's bounds.
ax2.set_ylim(w_min_i, w_bar_i + 0.2)

# Peaceful-phase components for the right panel.
# Type level reached by i's peaceful-phase concessions at T1 (HDC_i mapped back from
# CDF units); since the countries mix, a single horizontal line depicts it.
Beta_i = HDC_i(T1)*range_i + w_min_i
y_range_i = range_i + 0.2  # height of the y-window, for converting to axes fractions


ax2.axhline(y=Beta_i, xmin=0, xmax=endpoint_q)  # Q_i(t)
ax2.axvline(x=T1, ymin=0, ymax=(Beta_i - w_min_i)/y_range_i, ls='--')  # concessions stop here
ax2.fill_between(x=[0, T1], y1=Beta_i, y2=w_min_i,
                 facecolor='none', edgecolor='b', hatch='//')
ax2.annotate(r'$\beta_i^p$', xy=(0, Beta_i), xytext=(T1/9, Beta_i + y_range_i*0.1),
             arrowprops={'arrowstyle': 'simple'})
ax2.axvline(x=T1, ymin=(Beta_i - w_min_i)/y_range_i, ymax=(w_bar_i - w_min_i)/y_range_i,
            ls='--', color='k')
ax2.annotate(r'T$_p$', xy=(T1 - T1/15, w_bar_i + 0.05))

# Now plot i's screening-phase strategies, truncated at the horizon date Tbar.
enddate = np.where(t_ssp == Tbar)[0][0]  # position of Tbar within t_ssp
t_ssp_fig, sigma_i0_ssp_fig, sigma_i1_ssp_fig = (
    series[:enddate] for series in (t_ssp, sigma_i0_ssp, sigma_i1_ssp)
)

# Lower bound in the default colour, upper bound in red, both as type levels.
ax2.plot(t_ssp_fig, sigma_i0_ssp_fig*range_i + w_min_i)
ax2.plot(t_ssp_fig, sigma_i1_ssp_fig*range_i + w_min_i, color='r')

# Mass of i's types going to war at the horizon date, drawn as a red segment at Tbar.
w_i_min_Tbar = sigma_i0_ssp[enddate]*range_i + w_min_i
w_i_max_Tbar = sigma_i1_ssp[enddate]*range_i + w_min_i
ax2.axvline(x=Tbar, ymin=(w_i_min_Tbar - w_min_i)/y_range_i,
            ymax=(w_i_max_Tbar - w_min_i)/y_range_i, color='r')  # wars at the horizon date
ax2.axhline(y=w_bar_i, xmin=0, xmax=endpoint_q, ls='--', color='r')  # dashed line up to where war begins
ax2.annotate(r'$\widebar{w}_i$', xy=(0, w_bar_i), xytext=(T1/9, w_bar_i - y_range_i*0.1),
             arrowprops={'arrowstyle': 'simple', 'facecolor': 'red'})

# Black dashed continuations of the Tbar line above and below the red segment.
ax2.axvline(x=Tbar, ymin=(w_i_max_Tbar - w_min_i)/y_range_i,
            ymax=(w_bar_i - w_min_i)/y_range_i, ls='--', color='k')
ax2.axvline(x=Tbar, ymin=0.0, ymax=(w_i_min_Tbar - w_min_i)/y_range_i, ls='--', color='k')
ax2.annotate(r'$\bar{T}$', xy=(Tbar - Tbar/17, w_bar_i + 0.04))

plt.show()
